//+------------------------------------------------------------------+
//|                                                     linear_svm test.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "https://www.mql5.com/en/users/omegajoctan"
#property version   "1.00"
#property description "These inputs needs to have the same values as those of a script named GetDataforONNX"

#include <svm.mqh>
#resource "\\Files\\DualSVMONNX.onnx" as uchar SVMModel[]

//--- Magic number given to m_trade in OnInit and matched against open positions in PosExists
#define  MAGIC_NUMBER 12112023

//--- Helper objects from <svm.mqh>: matrix_utils performs the train/test split,
//--- metrics prints the confusion matrices in OnInit
CMatrixutils matrix_utils;
CMetrics  metrics;

//--- linear_svm is allocated in OnInit (LINEAR_SVM mode) and deleted in OnDeinit;
//--- dual_svm is loaded in OnInit from the embedded SVMModel ONNX resource
CLinearSVM *linear_svm;
CDualSVMONNX dual_svm;

#include <Trade\Trade.mqh>
#include <Trade\PositionInfo.mqh>

//--- m_trade sends the Buy/Sell orders in OnTick; m_position is used by PosExists
CTrade m_trade;
CPositionInfo m_position;

//--- Number of bars requested by GetTrainTestData for every indicator/price series
input uint bars = 1000;
//--- epochs_, batch_size_, alpha__ and lambda_ are passed to the CLinearSVM constructor in OnInit
input uint epochs_ = 1000;
input uint batch_size_ = 64;
input double alpha__ =0.1;
input double lambda_ = 0.01;

//--- Period of the RSI handle created in OnInit
input int rsi_period = 13;
//--- Period and deviation of the Bollinger Bands handle created in OnInit
input int bb_period = 20;
input double bb_deviation = 2.0;
//--- Seed passed to TrainTestSplitMatrices in OnInit
input int random_seed = 42;

//--- Selects which model OnInit prepares and OnTick queries
enum type {LINEAR_SVM, DUAL_SVM};

input group "TRADE PARAMS";

input type svm_type = DUAL_SVM;
//--- stop_loss and take_profit are multiplied by Point() in OnTick, i.e. they are distances in points
input int stop_loss = 3500;
input int take_profit = 1000;
//--- Passed to m_trade.SetDeviationInPoints in OnInit
input int slippage = 100;


//--- Indicator handles created in OnInit and read by GetTrainTestData / GetTradingData
int rsi_handle, 
    bb_handle;


//--- Bar count remembered by isnewBar between calls (0 = nothing recorded yet)
int prev_bars=0;
    
//--- Converted to float in OnInit and handed to dual_svm.LoadONNX together with the model.
//--- NOTE(review): presumably the per-feature min/max used to scale the training data
//--- (RSI + three Bollinger buffers) - confirm against the GetDataforONNX script.
vector min_v = {14.32424641,1.04674852,1.04799891,1.04392886};
vector max_v = {86.28263092,1.07385755,1.07907069,1.07267821};

//--- Working state written by OnTick

//--- Set in OnTick from SYMBOL_VOLUME_MIN
double lotsize;
//--- Set in OnTick from SYMBOL_TRADE_STOPS_LEVEL
int stops_level;
vectorf price_infof = {}; //feature vector returned by GetTradingData<float>()
vector price_infod = {}; //double-precision copy of price_infof, fed to the linear SVM
//--- Reset to 0 at the start of OnTick; OnTick buys on 1 and sells on -1
int predicted_class = 0;
//--- Filled by SymbolInfoTick in OnTick
MqlTick ticks;
//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
//--- Creates the indicator handles, configures the trade object and prepares the
//--- model selected by svm_type:
//---   LINEAR_SVM - trains a CLinearSVM on freshly collected data and prints
//---                train/test confusion matrices;
//---   DUAL_SVM   - loads the embedded ONNX model and, outside the strategy tester,
//---                prints train/test confusion matrices for it.
//--- Returns INIT_FAILED when an indicator handle or the linear model cannot be
//--- created, INIT_SUCCEEDED otherwise.
int OnInit()
  {
//--- Indicator handles read later by GetTrainTestData and GetTradingData
    rsi_handle = iRSI(Symbol(),PERIOD_CURRENT, rsi_period, PRICE_CLOSE);
    bb_handle = iBands(Symbol(), PERIOD_CURRENT, bb_period, 0 , bb_deviation, PRICE_CLOSE);

    if (rsi_handle == INVALID_HANDLE || bb_handle == INVALID_HANDLE)
      {
        Print("Failed to create the RSI/Bollinger Bands handles | Err = ",GetLastError());
        return(INIT_FAILED); //every feature comes from these handles, nothing can work without them
      }

//--- M-TRADE CONFIGS

   m_trade.SetExpertMagicNumber(MAGIC_NUMBER);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetTypeFillingBySymbol(Symbol());
   m_trade.SetMarginMode();

//--- Labels in double precision; the metrics calls below receive them in both modes

    vector y_train,
           y_test;

     switch(svm_type)
       {
        case LINEAR_SVM:
           {
               linear_svm = new CLinearSVM(batch_size_, alpha__, epochs_, lambda_);

               if (CheckPointer(linear_svm) == POINTER_INVALID)
                 {
                   Print("Failed to allocate the linear SVM model | Err = ",GetLastError());
                   return(INIT_FAILED);
                 }

               matrix dataset = GetTrainTestData<double>();
               matrix x_train,
                      x_test;

               matrix_utils.TrainTestSplitMatrices(dataset,x_train,y_train,x_test, y_test, 0.8, random_seed); //split the data into training and testing samples

               linear_svm.fit(x_train, y_train);

               vector train_pred = linear_svm.Predict(x_train),
                      test_pred = linear_svm.Predict(x_test);

               Print("\n<<<<< Train Classification Report >>>>\n");
               metrics.confusion_matrix(y_train, train_pred);

               Print("\n<<<<< Test  Classification Report >>>>\n");
               metrics.confusion_matrix(y_test, test_pred);

          }
       break;

   //--- float values for dual SVM

       case DUAL_SVM:
           {

            vectorf max_vf = {}, min_vf = {}; //converting the parameters into float type
            max_vf.Assign(max_v);
            min_vf.Assign(min_v);

            //NOTE(review): the result of LoadONNX is not inspected here because its
            //return type is not visible from this file - confirm in svm.mqh and fail
            //initialization when the model cannot be loaded.
            dual_svm.LoadONNX(SVMModel, ONNX_DEFAULT, max_vf, min_vf);

            if (!MQLInfoInteger(MQL_TESTER)) //Do not train-test on strategy tester | this model was trained once no need
              {
                  matrixf datasetf = GetTrainTestData<float>();
                  matrixf x_trainf,
                          x_testf;

                  vectorf y_trainf,
                          y_testf;

            //--- same seed input as the LINEAR_SVM branch instead of a hard-coded literal

                  matrix_utils.TrainTestSplitMatrices(datasetf,x_trainf,y_trainf,x_testf,y_testf,0.8,random_seed); //split the data into training and testing samples

                  y_train.Assign(y_trainf);
                  y_test.Assign(y_testf);

                  vector train_preds = dual_svm.Predict(x_trainf);
                  vector test_preds = dual_svm.Predict(x_testf);

                  Print("\n<<<<< Train Classification Report >>>>\n");
                  metrics.confusion_matrix(y_train, train_preds);

                  Print("\n<<<<< Test  Classification Report >>>>\n");
                  metrics.confusion_matrix(y_test, test_preds);
               }

          }
       break;
    }
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
//--- Frees everything OnInit acquired: the dynamically allocated linear model and
//--- the two indicator handles. Also clears the chart comment written by OnTick.
//--- reason - deinitialization reason code supplied by the terminal (not used).
void OnDeinit(const int reason)
  {
//--- the model exists only when the EA ran in LINEAR_SVM mode
   if (CheckPointer(linear_svm) != POINTER_INVALID)
     {
      delete (linear_svm);
      linear_svm = NULL; //avoid keeping a dangling pointer across re-initialization
     }

//--- release the indicators so the terminal can unload them
   if (rsi_handle != INVALID_HANDLE)
     {
      IndicatorRelease(rsi_handle);
      rsi_handle = INVALID_HANDLE;
     }

   if (bb_handle != INVALID_HANDLE)
     {
      IndicatorRelease(bb_handle);
      bb_handle = INVALID_HANDLE;
     }

   Comment(""); //do not leave a stale message on the chart
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
//--- Runs once per new bar of the current timeframe: reads the feature vector,
//--- asks the selected model for a class and opens a position in that direction
//--- when no position of the same type (same symbol and magic number) exists.
//--- Class 1 opens a buy, class -1 opens a sell, anything else is reported on the chart.
void OnTick()
  {
//---

     predicted_class = 0;

     if (!isnewBar(PERIOD_CURRENT)) //everything below is needed on a new bar only
        return;

     price_infof = GetTradingData<float>();
     price_infod.Assign(price_infof);

     if (!SymbolInfoTick(Symbol(), ticks))
       {
         Print("Failed to read the latest tick | Err = ",GetLastError());
         return; //without fresh bid/ask the order prices would be meaningless
       }

     lotsize = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);
     stops_level = (int)SymbolInfoInteger(Symbol(), SYMBOL_TRADE_STOPS_LEVEL);

     switch(svm_type)
       {
        case  DUAL_SVM:
           predicted_class = dual_svm.Predict(price_infof);
          break;
        case LINEAR_SVM:
           if (CheckPointer(linear_svm) == POINTER_INVALID)
             {
               Print("Linear SVM model is not available, skipping the prediction");
               return;
             }
           predicted_class = linear_svm.Predict(price_infod);
          break;
       }

//--- stop distances are widened by the broker's minimal stops level
     const double sl_distance = (stop_loss+stops_level)*Point();
     const double tp_distance = (take_profit+stops_level)*Point();
     const int    digits      = Digits();

     if (predicted_class == 1) //predicted bullish
       {
         if (!PosExists(POSITION_TYPE_BUY))
           {
             //third argument of CTrade::Buy is the PRICE, not the volume: buys open at ask
             if (!m_trade.Buy(lotsize, Symbol(), ticks.ask,
                              NormalizeDouble(ticks.bid-sl_distance, digits),
                              NormalizeDouble(ticks.bid+tp_distance, digits)))
                Print("Buy failed | retcode = ",m_trade.ResultRetcode()," ",m_trade.ResultRetcodeDescription());
           }
       }
     else if (predicted_class == -1)  //predicted bearish
       {
         if (!PosExists(POSITION_TYPE_SELL))
           {
             //third argument of CTrade::Sell is the PRICE, not the volume: sells open at bid
             if (!m_trade.Sell(lotsize, Symbol(), ticks.bid,
                               NormalizeDouble(ticks.ask+sl_distance, digits),
                               NormalizeDouble(ticks.ask-tp_distance, digits)))
                Print("Sell failed | retcode = ",m_trade.ResultRetcode()," ",m_trade.ResultRetcodeDescription());
           }
       }
     else
        Comment("Failed to Get the predicted class | Err = ",GetLastError());

  }
//+------------------------------------------------------------------+
//|   Getting data for Training and Testing the model                |
//+------------------------------------------------------------------+
//--- Builds the dataset used to train/evaluate the models.
//--- Returns a matrix with `bars` rows and 5 columns:
//---   0   - RSI buffer 0
//---   1-3 - Bollinger Bands buffers 0, 1 and 2
//---   4   - target label: 1 when the bar closed above its open, otherwise -1
//--- The function has no error channel, so a failed copy is logged and the affected
//--- column is left filled with zeros instead of undefined values.
template <typename T>
matrix<T> GetTrainTestData()
 {
   matrix<T> data(bars, 5);
   data.Fill(0); //deterministic content even when a copy below fails

   vector<T> v; //Temporary vector for storing Indicator buffers

//--- feature columns
   if (v.CopyIndicatorBuffer(rsi_handle, 0, 0, bars) && v.Size() == data.Rows())
      data.Col(v, 0);
   else
      Print(__FUNCTION__," failed to copy the RSI buffer | Err = ",GetLastError());

   for (ulong buffer=0; buffer<3; buffer++) //Bollinger Bands buffers 0..2 -> columns 1..3
     {
       if (v.CopyIndicatorBuffer(bb_handle, buffer, 0, bars) && v.Size() == data.Rows())
          data.Col(v, buffer+1);
       else
          Print(__FUNCTION__," failed to copy Bollinger Bands buffer ",buffer," | Err = ",GetLastError());
     }

//--- target column
   vector<T> open, close;
   if (!open.CopyRates(Symbol(), PERIOD_CURRENT, COPY_RATES_OPEN, 0, bars) ||
       !close.CopyRates(Symbol(), PERIOD_CURRENT, COPY_RATES_CLOSE, 0, bars))
      Print(__FUNCTION__," failed to copy open/close prices | Err = ",GetLastError());

//--- never index past what was really copied
   ulong rows = data.Rows();
   if (open.Size() < rows)
      rows = open.Size();
   if (close.Size() < rows)
      rows = close.Size();

   for (ulong i=0; i<rows; i++) //preparing the dependent (target) variable
     data[i][4] = close[i] > open[i] ? 1 : -1; // if price closed above its opening thats bullish else bearish

   return data;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//--- Collects the feature vector used for a live prediction.
//--- Returns 4 values read at buffer position 1 with a count of 1:
//---   [0] RSI buffer 0, [1..3] Bollinger Bands buffers 0, 1 and 2.
//--- NOTE(review): the results of CopyIndicatorBuffer are not checked, so a failed
//--- copy is not detected before buf[0] is read.
template <typename T>
vector<T> GetTradingData()
 {
   vector<T> features(4); //Independent variables only
   vector<T> buf;         //Receives one indicator value at a time

//--- RSI goes first
   buf.CopyIndicatorBuffer(rsi_handle, 0, 1, 1);
   features[0] = buf[0];

//--- then the three Bollinger Bands buffers, in buffer order
   for (ulong band=0; band<3; band++)
     {
       buf.CopyIndicatorBuffer(bb_handle, band, 1, 1);
       features[band+1] = buf[0];
     }

   return features;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//--- Reports whether the bar count of the current symbol on timeframe TF changed
//--- since the previous call.
//--- TF - timeframe whose bar count is watched.
//--- Returns true when the count differs from the remembered one; false on the first
//--- successful call (it only records the baseline), when nothing changed, or when
//--- the bar count is currently unavailable.
//--- The baseline lives in the single global prev_bars, so the function is meant to
//--- be used with one timeframe only.
bool isnewBar(ENUM_TIMEFRAMES TF)
 {
   const int current_bars = Bars(Symbol(), TF); //queried once per call

   if (current_bars <= 0) //history not ready: keep the baseline, do not signal a new bar
      return false;

   if (prev_bars == 0) //first successful call: remember the baseline only
     {
       prev_bars = current_bars;
       return false;
     }

   if (prev_bars == current_bars)
      return false;

   prev_bars = current_bars;
   return true;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//--- Tells whether this EA already holds a position of the given type.
//--- type - position type to look for.
//--- Returns true when an open position on the current symbol carries MAGIC_NUMBER
//--- and has the requested type, false otherwise.
bool PosExists(ENUM_POSITION_TYPE type)
 {
    int idx = PositionsTotal();

    while (--idx >= 0) //walk the open positions from the last one to the first
      {
        if (!m_position.SelectByIndex(idx))
           continue;

        if (m_position.Symbol() != Symbol())
           continue; //belongs to another symbol

        if (m_position.Magic() != MAGIC_NUMBER)
           continue; //opened by something else

        if (m_position.PositionType() == type)
           return (true);
      }

    return (false);
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
